/*
 * Copyright (c) 2021 yedf. All rights reserved.
 * Use of this source code is governed by a BSD-style
 * license that can be found in the LICENSE file.
 */

package dtmimp

import (
    "fmt"
    "strings"
)

// DBSpecial collects the SQL-generation operations that differ between the
// supported database types. One implementation exists per db type.
type DBSpecial interface {
    // GetXaSQL returns the statement used to carry out the given XA command
    // for the transaction branch identified by xid.
    GetXaSQL(command, xid string) string
    // GetInsertIgnoreTemplate returns an insert statement over tableAndValues
    // that does nothing when the row conflicts with an existing one.
    // pgConstraint names the constraint used by dialects that need one.
    GetInsertIgnoreTemplate(tableAndValues, pgConstraint string) string
    // GetPlaceHoldSQL returns sql with its `?` placeholders expressed in the
    // form the database expects.
    GetPlaceHoldSQL(sql string) string
}

// Package-level registry state:
//   - dbSpecials maps a db type name to its DBSpecial implementation.
//   - currentDBType is the db type whose implementation GetDBSpecial returns;
//     it starts out as DBTypeMysql.
var (
    dbSpecials    = make(map[string]DBSpecial)
    currentDBType = DBTypeMysql
)

// mysqlDBSpecial is the DBSpecial implementation for MySQL.
type mysqlDBSpecial struct{}

// GetXaSQL builds the MySQL XA statement for command, quoting xid in single
// quotes, e.g. GetXaSQL("start", "x1") yields "xa start 'x1'".
func (*mysqlDBSpecial) GetXaSQL(command, xid string) string {
    const xaTemplate = "xa %s '%s'"
    return fmt.Sprintf(xaTemplate, command, xid)
}

// GetInsertIgnoreTemplate prefixes tableAndValues with "insert ignore into".
// pgConstraint is not used by the MySQL dialect.
func (*mysqlDBSpecial) GetInsertIgnoreTemplate(tableAndValues, pgConstraint string) string {
    const insertIgnoreTemplate = "insert ignore into %s"
    return fmt.Sprintf(insertIgnoreTemplate, tableAndValues)
}

// GetPlaceHoldSQL returns sql unchanged; the `?` placeholders are kept as-is.
func (*mysqlDBSpecial) GetPlaceHoldSQL(sql string) string {
    return sql
}

func init() {
    dbSpecials[DBTypeMysql] = &mysqlDBSpecial{}
}

// postgresDBSpecial is the DBSpecial implementation for PostgreSQL.
type postgresDBSpecial struct{}

// GetXaSQL translates an XA command into the PostgreSQL statement used for it:
//
//	start    -> begin
//	prepare  -> prepare transaction '<xid>'
//	commit   -> commit prepared '<xid>'
//	rollback -> rollback prepared '<xid>'
//
// "end" and any unrecognized command yield the empty string.
func (*postgresDBSpecial) GetXaSQL(command string, xid string) string {
    switch command {
    case "start":
        return "begin"
    case "prepare":
        return fmt.Sprintf("prepare transaction '%s'", xid)
    case "commit":
        return fmt.Sprintf("commit prepared '%s'", xid)
    case "rollback":
        return fmt.Sprintf("rollback prepared '%s'", xid)
    default:
        // Covers "end", which maps to no statement, and unknown commands.
        return ""
    }
}

// GetPlaceHoldSQL rewrites each `?` in sql to a numbered placeholder
// ($1, $2, ...) in order of appearance and returns the result.
//
// Every `?` byte is rewritten regardless of context, including ones that
// appear inside quoted literals.
func (*postgresDBSpecial) GetPlaceHoldSQL(sql string) string {
    var out strings.Builder
    out.Grow(len(sql))
    rest := sql
    for pos := 1; ; pos++ {
        i := strings.IndexByte(rest, '?')
        if i < 0 {
            break
        }
        out.WriteString(rest[:i])
        fmt.Fprintf(&out, "$%d", pos)
        rest = rest[i+1:]
    }
    out.WriteString(rest)
    return out.String()
}

// GetInsertIgnoreTemplate returns an insert over tableAndValues that does
// nothing on conflict. When pgConstraint is non-blank the conflict target is
// that named constraint; when it is blank the conflict target is omitted so
// the generated statement stays syntactically valid.
func (*postgresDBSpecial) GetInsertIgnoreTemplate(tableAndValues string, pgConstraint string) string {
    if strings.TrimSpace(pgConstraint) == "" {
        // "ON CONSTRAINT" with no name is a syntax error; fall back to the
        // target-less form.
        return fmt.Sprintf("insert into %s on conflict do nothing", tableAndValues)
    }
    return fmt.Sprintf("insert into %s on conflict ON CONSTRAINT %s do nothing", tableAndValues, pgConstraint)
}
func init() {
    dbSpecials[DBTypePostgres] = &postgresDBSpecial{}
}

// GetDBSpecial returns the DBSpecial registered in dbSpecials for the
// package's current db type (currentDBType).
func GetDBSpecial() DBSpecial {
    special := dbSpecials[currentDBType]
    return special
}

// SetCurrentDBType makes dbType the package's current db type.
//
// dbType is looked up in dbSpecials first; when no implementation is
// registered for it, PanicIf receives a true condition together with an
// "unknown db type" error (presumably panicking — see PanicIf).
func SetCurrentDBType(dbType string) {
    unknown := dbSpecials[dbType] == nil
    PanicIf(unknown, fmt.Errorf("unknown db type '%s'", dbType))
    currentDBType = dbType
}

// GetCurrentDBType returns the db type name currently held in currentDBType,
// i.e. the initial DBTypeMysql or the value last stored by SetCurrentDBType.
func GetCurrentDBType() string {
    dbType := currentDBType
    return dbType
}
